package com.heima.lock.concurrent;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Callable;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.FutureTask;
import java.util.stream.Collectors;

public class Service {

    //要求：200ms内查出1000个人的信息
    public Map<Long, String> get(List<Long> userIds) {
        Map<Long, String> allResultMap = new HashMap<>();

        try {
            int userCount = userIds.size();
            int threadCount = userCount % 50 == 0 ? userCount / 50 : (userCount / 50) + 1;

            //定义线程计数器： 20把门栓
            CountDownLatch latch = new CountDownLatch(threadCount);

            List<FutureTask<Map<Long, String>>> tasks = new ArrayList<>();

            for (int i = 1, j = 1; i < userCount; i += 50, j++) {
                //数据分片
                List<Long> userFor50 = userIds.stream().limit(j * 50).skip((j - 1) * 50).collect(Collectors.toList());

                Callable<Map<Long, String>> callable = new Callable() {
                    @Override
                    public Map<Long, String> call() throws Exception {
                        UserService userService = new UserService();
                        Map<Long, String> userMap = userService.getUserMap(userFor50);
                        latch.countDown();      //解开一把门栓
                        return userMap;
                    }
                };

                FutureTask<Map<Long, String>> task = new FutureTask<Map<Long, String>>(callable);

                Thread t1 = new Thread(task);
                t1.start();
                tasks.add(task);
            }


            //期望：下面获取结果的代码什么时候执行： 上面20个线程都执行完了再执行
            latch.await();          // 基于门栓计数阻塞： 当所有门栓都下掉，也就是计数器为0 时，才会取消阻塞。
            for (int i = 0; i < tasks.size(); i++) {
                FutureTask<Map<Long, String>> task = tasks.get(i);
                Map<Long, String> map = task.get();
                allResultMap.putAll(map);
            }


        }catch (Exception e){
            e.printStackTrace();
        }

        return allResultMap;
    }

    public static void main(String[] args) {
        List<Long> userIds = new ArrayList<>();
        for (int i = 1; i <= 1000; i++) {
            userIds.add(Long.valueOf(i));
        }

        long start = System.currentTimeMillis();
        Service service = new Service();
        Map<Long, String> userMap = service.get(userIds);
        long end = System.currentTimeMillis();
        System.out.println(end - start);

        System.out.println(userMap);
    }
}

class UserService {

    /**
     * Stand-in lookup of user nicknames by id: every id is mapped to {@code "test"} and the call
     * then sleeps 90 ms, imitating a query that takes about 100 ms for 50 users.
     *
     * @param userIds ids to look up; at most 50 entries
     * @return a new mutable map from each id to its nickname
     * @throws IllegalArgumentException if {@code userIds} is {@code null} or has more than 50 entries
     */
    public Map<Long, String> getUserMap(List<Long> userIds) {
        if (userIds == null) {
            throw new IllegalArgumentException("userIds must not be null");
        }
        if (userIds.size() > 50) {
            throw new IllegalArgumentException("userids more than 50");
        }
        Map<Long, String> result = new HashMap<>();
        for (Long userId : userIds) {
            result.put(userId, "test");
        }
        try {
            Thread.sleep(90);
        } catch (InterruptedException e) {
            // sleep() clears the interrupt flag; restore it so callers can still observe the interrupt.
            Thread.currentThread().interrupt();
            e.printStackTrace();
        }
        return result;
    }
}

